import tensorflow as tf
import os
import numpy as np
from matplotlib import pyplot as plt
# Load the Fashion-MNIST dataset shipped with Keras. The variable keeps the
# name `mnist` even though this is the fashion variant.
mnist = tf.keras.datasets.fashion_mnist
train_split, test_split = mnist.load_data()
(x_train, y_train) = train_split
(x_test, y_test) = test_split

# Normalize the inputs by dividing by 255.0 (presumably uint8 pixel values in
# [0, 255] -> [0, 1]); smaller input features are easier for the network to learn.
x_train = x_train / 255.0
x_test = x_test / 255.0

# A small fully-connected classifier, built layer by layer.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten())  # flatten each input sample into a 1-D vector
model.add(tf.keras.layers.Dense(128, activation='relu'))  # hidden layer: 128 units, ReLU
model.add(tf.keras.layers.Dense(10, activation='softmax'))  # output layer: 10 units, softmax

model.compile(
    optimizer='adam',
    # The output layer already applies softmax, so the loss is fed
    # probabilities rather than logits (from_logits=False). The "sparse"
    # loss/metric expect integer class labels rather than one-hot vectors.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['sparse_categorical_accuracy'],
)

# Resume from an earlier run if it left a checkpoint behind (detected via the
# checkpoint's '.index' file).
checkpoint_save_path = "18 checkpoint/mnist.ckpt"
have_checkpoint = os.path.exists(checkpoint_save_path + '.index')
if have_checkpoint:
    print('——————————————load model——————————————')
    model.load_weights(checkpoint_save_path)

# Save only the weights, and only when the monitored metric improves
# (ModelCheckpoint's documented default monitor is 'val_loss').
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True,
)

# NOTE(review): the test split doubles as validation data here, so the "best"
# checkpoint is selected using test-set loss — consider a separate validation split.
history = model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=5,
    validation_data=(x_test, y_test),
    validation_freq=1,
    callbacks=[cp_callback],
)

model.summary()

# Print arrays in full instead of eliding the middle with '...'. This is a
# global NumPy setting and remains in effect for the rest of the script.
np.set_printoptions(threshold=np.inf)

print(model.trainable_variables)

# Dump every trainable variable (name, shape, values) to a text file. The
# `with` block guarantees the file is closed even if a write raises, and the
# explicit encoding avoids depending on the platform's default.
with open('./18 weights.txt', 'w', encoding='utf-8') as file:
    for v in model.trainable_variables:
        file.write(str(v.name) + '\n')
        file.write(str(v.shape) + '\n')
        file.write(str(v.numpy()) + '\n')
##############################################################################

# Per-epoch metrics recorded by model.fit() (validation runs every epoch).
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

# Two side-by-side panels: accuracy curves on the left, loss curves on the right.
panels = (
    (acc, 'Training Accuracy', val_acc, 'Validation Accuracy', 'Training and Validation Accuracy'),
    (loss, 'Training Loss', val_loss, 'Validation Loss', 'Training and Validation Loss'),
)
for position, (train_curve, train_label, val_curve, val_label, title) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.plot(train_curve, label=train_label)
    plt.plot(val_curve, label=val_label)
    plt.title(title)
    plt.legend()
plt.show()
##############################################################################